package httpserver

import (
	"context"
	"fmt"
	"net/http"
	"strings"
	"time"

	"gitee.com/amwfhv/tge/errors"
	"gitee.com/amwfhv/tge/log"
	"gitee.com/amwfhv/tge/middleware"
	"gitee.com/amwfhv/tge/options"
	"gitee.com/amwfhv/tge/response"
	"gitee.com/amwfhv/tge/version"
	"github.com/gin-contrib/pprof"
	"github.com/gin-gonic/gin"
	ginprometheus "github.com/zsais/go-gin-prometheus"
	"golang.org/x/sync/errgroup"
)

// HttpServer wraps a gin.Engine and serves it over plain HTTP and, when
// TLS material is configured, HTTPS. Create it with NewHttpServer, start it
// with Run and stop it with Close.
type HttpServer struct {
	// middlewares holds the names looked up in middleware.Middlewares by
	// InstallMiddlewares; unknown names are skipped with a warning.
	middlewares         []string
	// mode is handed to gin.SetMode in Setup.
	mode                string
	// SecureServingInfo describes the HTTPS listener (address, cert/key files, port).
	SecureServingInfo   *options.SecureServingInfo
	// InsecureServingInfo describes the plain HTTP listener.
	InsecureServingInfo *options.InsecureServingInfo
	// ShutdownTimeout is exported for callers to configure shutdown;
	// NewHttpServer does not populate it, so it starts at zero.
	ShutdownTimeout     time.Duration
	// Embedded router: its methods (Use, GET, ServeHTTP, ...) are promoted
	// onto HttpServer, which is why Run can pass the server itself as Handler.
	*gin.Engine
	// healthz registers GET /healthz and enables Run's start-up self-ping.
	healthz                      bool
	// enableMetrics installs go-gin-prometheus instrumentation in InstallAPIs.
	enableMetrics                bool
	// enableProfiling registers the gin-contrib/pprof routes in InstallAPIs.
	enableProfiling              bool
	// Created by Run; both are nil until Run has been called.
	insecureServer, secureServer *http.Server
}

// NewHttpServer builds an HttpServer from cfg and fully initialises it before
// returning: gin mode, route logging, middlewares and the built-in routes are
// all installed. cfg must be non-nil. ShutdownTimeout is not read from cfg and
// is left at zero.
func NewHttpServer(cfg *Config) *HttpServer {
	srv := &HttpServer{
		mode:                cfg.Mode,
		middlewares:         cfg.Middlewares,
		healthz:             cfg.Healthz,
		enableMetrics:       cfg.EnableMetrics,
		enableProfiling:     cfg.EnableProfiling,
		SecureServingInfo:   cfg.SecureServing,
		InsecureServingInfo: cfg.InsecureServing,
	}
	// The engine is created only after every cfg field has been read.
	srv.Engine = gin.New()

	srv.initHttpServer()

	return srv
}

// initHttpServer runs the initialisation steps in a fixed order: gin settings,
// then middlewares, then routes. Middlewares must come before routes because
// gin fixes a route's handler chain at the moment the route is registered.
func (s *HttpServer) initHttpServer() {
	steps := []func(){
		s.Setup,
		s.InstallMiddlewares,
		s.InstallAPIs,
	}
	for _, step := range steps {
		step()
	}
}

// Setup applies the configured gin mode and routes gin's route-registration
// debug output through the project logger. Both gin.SetMode and
// gin.DebugPrintRouteFunc are package-level gin settings, so they affect every
// gin engine in the process, not only this server.
func (s *HttpServer) Setup() {
	logRoute := func(method, path, handler string, handlerCount int) {
		log.Infof("%-6s %-s --> %s (%d handlers)", method, path, handler, handlerCount)
	}

	gin.SetMode(s.mode)
	gin.DebugPrintRouteFunc = logRoute
}

// InstallMiddlewares always installs the request-ID middleware first, then
// each configured middleware in list order. Names missing from the
// middleware.Middlewares registry are logged as warnings and skipped rather
// than treated as errors.
func (s *HttpServer) InstallMiddlewares() {
	s.Use(middleware.RequestID())

	for _, name := range s.middlewares {
		if handler, found := middleware.Middlewares[name]; found {
			log.Infof("install middleware: %s", name)
			s.Use(handler)
		} else {
			log.Warnf("can not find middleware: %s", name)
		}
	}
}

// InstallAPIs registers the server's built-in endpoints:
//   - GET /healthz (only when healthz is enabled),
//   - Prometheus instrumentation via go-gin-prometheus (when enableMetrics),
//   - pprof routes (when enableProfiling),
//   - GET /version (always).
//
// Keep this order: gin captures a route's middleware chain at registration
// time, so moving routes relative to the Prometheus Use call would change
// which routes are instrumented.
func (s *HttpServer) InstallAPIs() {
	if s.healthz {
		healthzHandler := func(c *gin.Context) {
			response.WriteResponse(c, nil, map[string]string{"status": "ok"})
		}
		s.GET("/healthz", healthzHandler)
	}

	if s.enableMetrics {
		ginprometheus.NewPrometheus("gin").Use(s.Engine)
	}

	if s.enableProfiling {
		pprof.Register(s.Engine)
	}

	versionHandler := func(c *gin.Context) {
		response.WriteResponse(c, nil, version.Get())
	}
	s.GET("/version", versionHandler)
}

// Run starts the plain HTTP listener and, when a certificate file, key file
// and non-zero bind port are all configured, the HTTPS listener. It then
// blocks until both listener goroutines have finished.
//
// When healthz is enabled, Run first polls GET /healthz with a 10 second
// budget. If that fails, both servers are shut down via Close and the ping
// error is returned, so a failed start-up does not leave listeners behind.
//
// A listener error other than http.ErrServerClosed is reported with
// log.Fatal (NOTE(review): presumably this terminates the process — confirm
// against the project log package). Should that call return, the error is
// now propagated to the caller instead of being dropped.
func (s *HttpServer) Run() error {
	s.insecureServer = &http.Server{
		Addr:    s.InsecureServingInfo.Address(),
		Handler: s,
	}

	s.secureServer = &http.Server{
		Addr:    s.SecureServingInfo.Address(),
		Handler: s,
	}

	var eg errgroup.Group

	eg.Go(func() error {
		log.Infof("Start to listening the incoming requests on http address:%s", s.insecureServer.Addr)

		if err := s.insecureServer.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
			log.Fatal(err.Error())

			return err
		}
		log.Infof("Server on %s stopped", s.insecureServer.Addr)

		return nil
	})

	eg.Go(func() error {
		// HTTPS is optional: without both cert/key files and a port, skip it.
		key, cert := s.SecureServingInfo.CertKey.KeyFile, s.SecureServingInfo.CertKey.CertFile
		if cert == "" || key == "" || s.SecureServingInfo.BindPort == 0 {
			return nil
		}

		log.Infof("Start to listening the incoming requests on https address: %s", s.secureServer.Addr)

		if err := s.secureServer.ListenAndServeTLS(cert, key); err != nil && !errors.Is(err, http.ErrServerClosed) {
			log.Fatal(err.Error())

			return err
		}

		log.Infof("Server on %s stopped", s.secureServer.Addr)

		return nil
	})

	if s.healthz {
		// Scope the timeout to the ping itself rather than to Run's lifetime.
		ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
		err := s.ping(ctx)
		cancel()
		if err != nil {
			// Don't leave listeners running behind a failed start-up.
			s.Close()

			return err
		}
	}

	// The listener goroutines already logged their failure; hand the error to
	// the caller rather than logging it again and returning nil.
	return eg.Wait()
}

// Close gracefully shuts down the HTTPS and HTTP servers created by Run.
// Both shutdowns share a single deadline of ShutdownTimeout, or 10 seconds
// when ShutdownTimeout is not positive. Failures are logged as warnings and
// are not returned. Calling Close before Run (servers still nil) is a no-op
// instead of a nil-pointer panic.
//
// NOTE(review): Close reads s.secureServer/s.insecureServer, which Run assigns
// without synchronization; calling Close concurrently with Run's start-up is
// a data race.
func (s *HttpServer) Close() {
	timeout := s.ShutdownTimeout
	if timeout <= 0 {
		timeout = 10 * time.Second
	}

	ctx, cancel := context.WithTimeout(context.Background(), timeout)
	defer cancel()

	if s.secureServer != nil {
		if err := s.secureServer.Shutdown(ctx); err != nil {
			log.Warnf("Shutdown secure server failed: %s", err.Error())
		}
	}

	if s.insecureServer != nil {
		if err := s.insecureServer.Shutdown(ctx); err != nil {
			log.Warnf("Shutdown insecure server failed: %s", err.Error())
		}
	}
}

// ping polls GET /healthz on the insecure listener about once per second
// until it answers 200 OK or ctx is done. A wildcard "0.0.0.0:<port>" bind
// address is probed via 127.0.0.1 on the same port.
//
// It returns nil once the router responds with 200, the error from building
// the request, or an error wrapping ctx.Err() when ctx expires first.
func (s *HttpServer) ping(ctx context.Context) error {
	const wildcardPrefix = "0.0.0.0:"

	addr := s.insecureServer.Addr
	// Match only a leading wildcard host: a substring test would also match
	// addresses such as "10.0.0.0:80", and splitting on ":" could panic.
	if strings.HasPrefix(addr, wildcardPrefix) {
		addr = "127.0.0.1:" + strings.TrimPrefix(addr, wildcardPrefix)
	}
	url := fmt.Sprintf("http://%s/healthz", addr)

	for {
		req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
		if err != nil {
			return err
		}

		resp, err := http.DefaultClient.Do(req)
		if err == nil {
			// Close the body for every response, not only 200s, so non-OK
			// replies do not leak connections.
			healthy := resp.StatusCode == http.StatusOK
			resp.Body.Close()

			if healthy {
				log.Info("The router has been deployed successfully.")

				return nil
			}
		}

		log.Info("Waiting for the router, retry in 1 second.")

		// Wait for the retry interval, but stop as soon as ctx is done.
		select {
		case <-ctx.Done():
			return fmt.Errorf("can not ping http server within the specified time interval: %w", ctx.Err())
		case <-time.After(time.Second):
		}
	}
}
